# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function
import os
import sys
import time
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
import neighbor_codec

######################################################
# Command-line parsing
######################################################


parser = argparse.ArgumentParser()

def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)

group = parser.add_argument_group('dataset options')

aa('--db', default='deep1M', help='dataset')
aa( '--compute_gt', default=False, action='store_true',
    help='compute and store the groundtruth')

group = parser.add_argument_group('index consturction')

aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
   help='HNSW construction factor')
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements index by batches of this size')
aa('--link_singletons', default=False, action='store_true',
   help='do a pass to link in the singletons')

group = parser.add_argument_group(
    'searching (reconstruct_from_neighbors options)')

aa('--beta_centroids', default='',
   help='file with codebook')
aa('--neigh_recons_codes', default='',
   help='file with codes for reconstruction')
aa('--beta_ntrain', default=250000, type=int, help='')
aa('--beta_k', default=256, type=int, help='beta codebook size')
aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors')
aa('--beta_niter', default=10, type=int, help='')
aa('--k_reorder', default='-1', help='')

group = parser.add_argument_group('searching')

aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--exhaustive', default=False, action='store_true',
    help='report the exhaustive search topline')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--efSearch', default='', type=str,
   help='comma-separated values of efSearch to try')

args = parser.parse_args()

print("args:", args)


######################################################
# Load dataset
######################################################

xt, xb, xq, gt = datasets.load_data(
    dataset=args.db, compute_gt=args.compute_gt)

nq, d = xq.shape
nb, d = xb.shape


######################################################
# Make index
######################################################

if os.path.exists(args.indexfile):

    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x:x

    hnsw = index_hnsw.hnsw
    hnsw_stats = faiss.cvar.hnsw_stats

else:

    print("build index, key=", args.indexkey)

    index = faiss.index_factory(d, args.indexkey)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x:x

    hnsw = index_hnsw.hnsw
    hnsw.efConstruction = args.efConstruction
    hnsw_stats = faiss.cvar.hnsw_stats
    index.verbose = True
    index_hnsw.verbose = True
    index_hnsw.storage.verbose = True

    if args.M0 != -1:
        print("set level 0 nb of neighbors to", args.M0)
        hnsw.set_nb_neighbors(0, args.M0)

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))

    print("train, size", xt.shape)
    t0 = time.time()
    index.train(xt2)
    print("  train in %.3f s" % (time.time() - t0))

    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        index.add(sanitize(xb))
    else:
        for i0 in range(0, nb, args.add_bs):
            i1 = min(nb, i0 + args.add_bs)
            print("  adding %d:%d / %d" % (i0, i1, nb))
            index.add(sanitize(xb[i0:i1]))

    print("  add in %.3f s" % (time.time() - t0))
    print("storing", args.indexfile)
    faiss.write_index(index, args.indexfile)


######################################################
# Train beta centroids and encode dataset
######################################################

if args.beta_centroids:
    print("reordering links")
    index_hnsw.reorder_links()

    if os.path.exists(args.beta_centroids):
        print("load", args.beta_centroids)
        beta_centroids = np.load(args.beta_centroids)
        nsq, k, M1 = beta_centroids.shape
        assert M1 == hnsw.nb_neighbors(0) + 1

        rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq)
    else:
        print("train beta centroids")
        rfn = faiss.ReconstructFromNeighbors(
            index_hnsw, args.beta_k, args.beta_nsq)

        xb_full = vec_transform(sanitize(xb[:args.beta_ntrain]))

        beta_centroids = neighbor_codec.train_beta_codebook(
            rfn, xb_full, niter=args.beta_niter)

        print("  storing", args.beta_centroids)
        np.save(args.beta_centroids, beta_centroids)


    faiss.copy_array_to_vector(beta_centroids.ravel(),
                               rfn.codebook)
    index_hnsw.reconstruct_from_neighbors = rfn

    if rfn.k == 1:
        pass     # no codes to take care of
    elif os.path.exists(args.neigh_recons_codes):
        print("loading neigh codes", args.neigh_recons_codes)
        codes = np.load(args.neigh_recons_codes)
        assert codes.size == rfn.code_size * index.ntotal
        faiss.copy_array_to_vector(codes.astype('uint8'),
                                   rfn.codes)
        rfn.ntotal = index.ntotal
    else:
        print("encoding neigh codes")
        t0 = time.time()

        bs = 1000000 if args.add_bs == -1 else args.add_bs

        for i0 in range(0, nb, bs):
            i1 = min(i0 + bs, nb)
            print("   encode %d:%d / %d [%.3f s]\r" % (
                i0, i1, nb, time.time() - t0), end=' ')
            sys.stdout.flush()
            xbatch = vec_transform(sanitize(xb[i0:i1]))
            rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch))
        print()

        print("storing %s" % args.neigh_recons_codes)
        codes = faiss.vector_to_array(rfn.codes)
        np.save(args.neigh_recons_codes, codes)

######################################################
# Exhaustive evaluation
######################################################

if args.exhaustive:
    print("exhaustive evaluation")
    xq_tr = vec_transform(sanitize(xq))
    index2 = faiss.IndexFlatL2(index_hnsw.d)
    accu_recons_error = 0.0

    if faiss.get_num_gpus() > 0:
        print("do eval on GPU")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = False
        index2 = faiss.index_cpu_to_all_gpus(index2, co)

    # process in batches in case the dataset does not fit in RAM
    rh = datasets.ResultHeap(xq_tr.shape[0], 100)
    t0 = time.time()
    bs = 500000
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        print('  handling batch %d:%d' % (i0, i1))

        xb_recons = np.empty(
            (i1 - i0, index_hnsw.d), dtype='float32')
        rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons))

        accu_recons_error += (
            (vec_transform(sanitize(xb[i0:i1])) -
             xb_recons)**2).sum()

        index2.reset()
        index2.add(xb_recons)
        D, I = index2.search(xq_tr, 100)
        rh.add_batch_result(D, I, i0)

    rh.finalize()
    del index2
    t1 = time.time()
    print("done in %.3f s" % (t1 - t0))
    print("total reconstruction error: ", accu_recons_error)
    print("eval retrieval:")
    datasets.evaluate_DI(rh.D, rh.I, gt)


def get_neighbors(hnsw, i, level):
    " list the neighbors for node i at level "
    assert i < hnsw.levels.size()
    assert level < hnsw.levels.at(i)
    be = np.empty(2, 'uint64')
    hnsw.neighbor_range(i, level, faiss.swig_ptr(be), faiss.swig_ptr(be[1:]))
    return [hnsw.neighbors.at(j) for j in range(be[0], be[1])]


#############################################################
# Index is ready
#############################################################

xq = sanitize(xq)

if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)


if gt is None:
    print("no valid groundtruth -- exit")
    sys.exit()


k_reorders = [int(x) for x in args.k_reorder.split(',')]
efSearchs = [int(x) for x in args.efSearch.split(',')]


for k_reorder in k_reorders:

    if index_hnsw.reconstruct_from_neighbors:
        print("setting k_reorder=%d" % k_reorder)
        index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder

    for efSearch in efSearchs:
        print("efSearch=%-4d" % efSearch, end=' ')
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        datasets.evaluate(xq, gt, index, k=args.k, endl=False)

        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))
